{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9a14dead",
   "metadata": {},
   "source": [
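    "## 概述\n",
    "\n",
    "$$\n",
    "H_{1} = \\mathrm{relu}(XW_{1} + b_{1})\n",
    "$$\n",
    "\n",
    "$$\n",
    "H_{2} = \\mathrm{relu}(H_{1}W_{2} + b_{2})\n",
    "$$\n",
    "\n",
    "$$\n",
    "H_{3} = f(H_{2}W_{3} + b_{3})\n",
    "$$\n",
    "\n",
    "Here $X$ is a batch of flattened $28 \\times 28$ MNIST images, and $f$ is the identity: the last layer outputs raw scores for the 10 digit classes.\n",
    "\n",
    "#### 4 steps\n",
    "- Load data\n",
    "- Build model\n",
    "- Train\n",
    "- Test"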
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bd7e0a66",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "def plot_curve(data):\n",
    "    \"\"\"Plot a sequence of values (e.g. the per-batch training loss) against step index.\n",
    "\n",
    "    Args:\n",
    "        data: sequence of numbers; element i is drawn at x = i.\n",
    "    \"\"\"\n",
    "    _, ax = plt.subplots()\n",
    "    ax.plot(range(len(data)), data, color='blue')\n",
    "    ax.legend(['value'], loc='upper right')\n",
    "    ax.set_xlabel('step')\n",
    "    ax.set_ylabel('value')\n",
    "    plt.show()\n",
    "\n",
    "def plot_image(img, label, name, n=6, mean=0.1307, std=0.3081):\n",
    "    \"\"\"Show up to `n` images from a batch in a 3-column grid, each titled `name: label`.\n",
    "\n",
    "    Args:\n",
    "        img: image batch; img[i][0] must be a 2-D image (e.g. a [b, 1, 28, 28] tensor).\n",
    "        label: per-image values; label[i].item() goes into the title.\n",
    "        name: title prefix shared by every subplot, e.g. 'image sample'.\n",
    "        n: maximum number of images to show (default 6, i.e. a 2 x 3 grid).\n",
    "        mean, std: the statistics given to transforms.Normalize; used here to undo\n",
    "            the standardization so pixels are shown on their original scale.\n",
    "    \"\"\"\n",
    "    count = min(n, len(img))  # tolerate batches smaller than n\n",
    "    cols = 3\n",
    "    rows = max(1, (count + cols - 1) // cols)\n",
    "    plt.figure()\n",
    "    for i in range(count):\n",
    "        plt.subplot(rows, cols, i + 1)\n",
    "        # Undo Normalize: original = standardized * std + mean\n",
    "        plt.imshow(img[i][0] * std + mean, cmap='gray', interpolation='none')\n",
    "        plt.title('{}: {}'.format(name, label[i].item()))\n",
    "        plt.xticks([])\n",
    "        plt.yticks([])\n",
    "    # Lay out once, after every subplot and title exists, so titles do not overlap.\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# one hot 编码\n",
    "def one_hot(label, depth=10):\n",
    "    out = torch.zeros(label.size(0), depth)\n",
    "    idx = torch.LongTensor(label).view(-1, 1)\n",
    "    out.scatter_(dim=1, index=idx, value=1)\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c259c617",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([512, 1, 28, 28]) torch.Size([512]) tensor(-0.4242) tensor(2.8215)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZUAAAELCAYAAAARNxsIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAfWElEQVR4nO3debRUxbn38d+jCE4gKmIEBBMRl5qoMbDApWbwGjXEIQkiuTgkZGnGNw7ghFFfX4erRgzqXXHKTWKiRg3coImJxGgkRgZFvCY33uhdKjIbPDIIaES03j/2ZlNV0n26+1Sf7j58P2udZT2ndu+u5pT99K6qrm3OOQEAkMJWjW4AAKDrIKkAAJIhqQAAkiGpAACSIakAAJIhqQAAkqk6qZjZ82b26fRN6VrMbC8zc2bWrdFtaRb0ncrQdz6IvlOZZug7VScV59wBzrkZdWgLSjCzHmY22cyWmtlKM7vFzLZpdLuqRd/pfJa5ysyWmNlqM5thZgc0ul3Vou90PjP7spm9mPeb5Wb2MzPr1d7jGP5qDRdJGirpo5KGSDpE0iUNbRFaxWhJX5N0hKRdJM2WdFdDW4RWMVPSYc65nSR9RFI3SVe196Bahr9eNbOj8vLlZjbFzO42szVm9t9mNsTMJuaZbZGZHe09dpyZ/T0/9hUz+0Z07gvMbFn+ifyM/DJucF7Xw8wmmdlCM/uHmd1mZtuVaONgM/tTnmHbzOx+r+6mvF1vmtk8MzvCq6v29cwws2vM7On8uR40s11KtGknM/tx/vqW5J8et67wn/14STc751Y4516XdLOyN4qWQt9pSN/5sKQnnXOvOOfek3S3pP0rfGzToO90ft9xzi1yzrV5v3pP0uD2HpfiSuV4ZZ98dpb0X5J+n5+3v6QrJN3uHbtc0nGSekkaJ2mymR0iSWZ2rKTxko7KG/6p6HmuU/Yp/eC8vr+ky0q06UpJj+RtGiDp3726ufk5dpH0C0lTzGzbGl+PJJ2u7A2+n6QNyt7wN+dnef1gSR+XdLSkM/LXPtDMVpnZwBKPtfzHjweY2U4ljm8V9J369537JA3O36S2kfQVSdNLHNtK6Dv17zsys8PNbLWkNZJGSbqx1LEF51xVP5JelXRUXr5c0h+8uuMlrZW0dR73lOQk9S5xrgcknZ2XfyLpGq9ucP7YwcreRNdJ2turP1TS/BLn/bmkOyQNqOD1rJR0UC2vR9IMSdd6x+8vab2krSXtlR/bTdLukt6RtJ137L9KerzCf/OrlF2K7ibpQ5Keys+9R7V/v0b+0Hca0ne6S7opP98GSfMlfbjRfYG+0/x9J2pv/7ydQ9o7NsWVyj+88tuS2lx2mb0xlqQdJcnMPmdmc8xshZmtkjRSUp/8mH6SFnnn8su7Sdpe0rw8s65S9mlrtxJtukBZh3jaslUjxVCRmU3IL4VX5+fZyWtDVa9nM+1cIGmb6HySNCj//TKv/bdL6lui/bGrlX16eU7SLGX/U7yr7BNYK6PvbFKvvvN/JQ2TtKekbSX9P0l/NLPtK3x8s6LvbFKvvlNwzi1R9trva+/YTlt2ZmY9JP2nssu2B51z75rZA9o0rLNM2SXjRnt65TZl/7AH5C+uLOfca5LOzJ/3cEmPmtkTkvaQdKGkf5H0vHPufTNbqXBoqVp+Owcqe7Nvi36/SNknhj7OuQ3VPoFz7m1J/yf/kZl9XdI8r9N1afSd2vuOpIMk3e+cW5zHd5rZjco+3T5Tw/laCn2nQ30n1k3S3u0d1Jmrv7pL6iHpdUkbzOxzysb3NvqlpHFmtl/+KaoYt3TOvS/pR8rGQvtKkpn1N7NjNvdEZjbazDZ2lJXKLgffU3YZuSFvQzczu0zZOGtHnGpm++dtvkLS1PjN3jm3TNlY6w1m1svMtjKzvc0sHr/drPy19rPMCEmXKvsEuqWg79TYd5SN5Y82s93zx56m7NPrSx1se6ug79T+vnNKPu9iZjZI2YjJY+09
rtOSinNujaSzlP0RV0oaK+nXXv3DyiabHlfW4WfnVe/k/70w//0cM3tT0qOS9i3xdMMkPWVma/PnONs5N1/Z5NfDkv5X2SXjPxVeRtbiLkl3SnpN2fDCWSWOO11ZB/8fZa9/qrJPMBsnzNZa6QmzvZUNe61TNvF2kXPukQ62u2XQdzrUd66T9BdlQ6erJJ0raZRzblUH294S6Dsd6jv7K3vfWatsTvdF5Vdi5Vg+CdN0zGw/SX+T1CPRpVtyZjZD0t3Ouf9odFuwCX0HtaLvdFxTffnRzL5oZt3NbGdln7B+06x/WDQX+g5qRd9Jq6mSiqRvKBt3fFnZWOS3GtsctBD6DmpF30moaYe/AACtp9muVAAALYykAgBIpqovP5oZY2VNyDnXkS9R1R39pmm1OedKfTu8KdB3mlbJvsOVCrDlWtDoBqBllew7JBUAQDIkFQBAMiQVAEAyJBUAQDIkFQBAMiQVAEAyJBUAQDIkFQBAMiQVAEAyJBUAQDIkFQBAMiQVAEAyJBUAQDIkFQBAMiQVAEAyVd2kC+gKRowYEcQDBw4M4kmTJhXlKVOmBHUTJkyoX8PQafbcc8+ifMMNN5Q9dsCAAUX50EMPrfk5Z8+eHcSLFy8ueeyiRYuC+MYbbyxZ12y4UgEAJENSAQAkQ1IBACRjzrnKDzar/OBE+vfvH8TXXnttkvP269cviNva2ory+vXrg7p99903iIcOHVqUr7/++qDukksuKcrvvvtuh9tZCeecdcoT1agR/Sbmj6HPnDmzZF17xo8fX7Ju8uTJ1TesseY554a2f1jjpOo78TxaPL9Rrq7c3Ee9jB49umRdPAfYoDmWkn2HKxUAQDIkFQBAMk0//DVt2rQgPvHEE4tyNW2PmYUjRrWeKz5Pnz59ivKKFStqOme1GP76oHhIyx/yiuvi4YNqhsN88fLjk08+uabzdKIuPfw1a9asohwvBfb/5ocddljJukaJ++DChQuLcpP0M4a/AAD1R1IBACRDUgEAJNN027RsvfXWQdy7d++az7VmzZqi3LNnz5rPU+68559/flC3evXqZM+D2p1zzjlB7I9Rx0tGx4wZE8T++Ht8Hn/LDkmaM2dOUY6XgcbLWP1jUX/+3zleCt7sf4ty28E0w5xPOVypAACSIakAAJIhqQAAkmm676nsuuuuQbx8+fK4DUV5wYIFQd0ZZ5wRxK+++mrJ85YTz7/cddddQeyPaX7yk58M6t55552KnycVvqfyQb/85S+D2J/viOdQ4mPLKfcdF/97EdIHt/dowu+tdOnvqbSSct9LkcJ+Fm/T0iB8TwUAUH8kFQBAMk23pHivvfaq+Njvfve7QfzYY4+VPPbll1+u+LzxzsjbbrttEPu7FP/qV78K6r70pS8V5UYMhaG+yi3nbMRutmhd/h0ny+1+LX1wK5lmxpUKACAZkgoAIBmSCgAgmaabU4nHDuOt5f04XsLZEf52MPfdd1/JutjIkSODeIcddijKzKk0zg9+8IMg9pcUx1uvDB8+PIhvvPHGotzelhj+VizxNi3tjZOj6/GXjcf9qlx/iPtZvPy82bdm8XGlAgBIhqQCAEiGpAIASKbp5lSefvrpIC63jcyoUaOCON5OxZ8L8ec6JOmEE04IYn/rjmHDhlXchueffz6I33rrrZLHovPEW5v7t2CN5z7ibcb9se/2tnSJ5258S5YsqayxaKh4S5T4+0Z+HPedVOJtWuLnafat+n1cqQAAkiGpAACSabpdiuM7Pz766KNB/KlPfaoox22Pdy32h7x69OgR1PXq1SuIq/l38H3/+98P4okTJ9Z0no5gl+LqxMs1q1n6Gd810h8684fYNvc8TYhdilV+R+tqxX3AN3Xq1CD2+108BFvuzo9xXYOGxtilGABQfyQVAEAyJBUAQDJNN6cSO+aYY4L4uuuuK8of+9jHaj7vunXrgnjDhg1Feaeddqr4PLvvvnsQt7W11dymWjGnkta555672bL0waWfvo7c
UbJBmFNpUnG/K7d03e93ndjnmFMBANQfSQUAkAxJBQCQTNPPqcR22223ovyJT3yi5vP89a9/DeIjjjiiKP/iF78o+9j58+cX5cGDB9fchlSYU6mfar7DEG9PHt/GoQm3L2dOpUX4c3kzZ84sWTdw4MCgro59jjkVAED9kVQAAMk03S7F7Xn99deL8vTp02s+jz+MJkl33HFHUY7vNrl+/fogvuiii2p+XnQt5XY/vuGGG4K4BbZtQZPyh7HiYVV/l+X4rqYTJkyoa7s2hysVAEAyJBUAQDIkFQBAMi23pDiVSy+9NIgvv/zykseuWLEiiOP5mEZjSXFa/hLN+K6A8XYZ/pj1iBEjgrp4OfLkyZM3W24glhR3Af57eLyEOF5inBBLigEA9UdSAQAkQ1IBACSzxcypHHjggUH85z//OYh33HHHovzGG28EdSNHjgziZ555JnHrOoY5lbT875fEtxauZhuM+HspkyZNKnmeBmFOpQuYNWtWUY5vNRx/5y4h5lQAAPVHUgEAJNNy27RUY6utNuXML3/5y0Fdz549g9gfBlywYEFQ12zDXaivcnd3rGbX13hJsT/8FS8/njNnTsXnBZoZVyoAgGRIKgCAZEgqAIBkuvSS4pNOOqko33///UFdvNTO/3fYd999g7qXXnqpDq1LhyXFaflLgeN+M2bMmCCO500qPa/fN+O6TsSS4i6AbVoAAF0WSQUAkEyXWlI8ePDgIL7yyiuLcjzcFcdPPfVUUe7du3f6xqFl+ENa8fDX8OHDSx4bi5cm+0uKWUKMavh9Kb67o++8887rhNaUx5UKACAZkgoAIBmSCgAgmS41p3LWWWcF8ZAhQ4pye0unDznkkKI8duzYoI5tWrZcU6ZMCeLRo0cHcbktXeKtWPxjG7SEuEuJ/w1nz54dxNVsqdNszj333CCO7zjq85e5V7PEvV64UgEAJENSAQAkQ1IBACTT0nMqp512WhB//etfr/ix69atC+Kf/OQnRfnCCy/sWMPQZUyYMCGIFy5cGMTl5lRi/pg/31PpuPg7Q/F3ivw5lfj7G/7folFzL/6cUPzdk/gOjn4b47mkZutLXKkAAJIhqQAAkmnpXYpfeOGFIN5nn31KHvvOO+8E8U033RTEEydOTNewTsYuxZ0nHnrwh1ziYZRmH6ZQF9+l2F+WW25Jbixemrx48eJamxAot8Q8Frc3HoZtAuxSDACoP5IKACAZkgoAIJmWnlOZO3duEPtbrUjSiy++WJTHjx8f1E2fPr1+DetkzKmgRl16TgV1xZwKAKD+SCoAgGRIKgCAZFp6m5Zhw4Y1ugkAAA9XKgCAZEgqAIBkSCoAgGRIKgCAZEgqAIBkSCoAgGSqXVLcJmlBPRqCmg1qdAMqQL9pTvQd1Kpk36lq7y8AAMph+AsAkAxJBQCQDEkFAJAMSQUAkAxJBQCQDEkFAJAMSQUAkAxJBQCQDEkFAJAMSQUAkAxJBQCQDEkFAJAMSQUAkEzVScXMnjezT6dvStdiZnuZmTOzam8v0GXRdypD3/kg+k5lmqHvVJ1UnHMHOOdm1KEtKMHMvmxmL5rZajNbbmY/M7NejW5Xteg7jWVmf2z0G06t6DuNYWYfMbOHzGyNmbWZ2ffbewzDX61hpqTDnHM7SfqIspurXdXYJqGVmNkpqv6mfNiCmVl3SX+Q9EdJH5I0QNLd7T7QOVfVj6RXJR2Vly+XNCV/ojWS/lvSEEkTJS2XtEjS0d5jx0n6e37sK5K+EZ37AknLJC2VdIYkJ2lwXtdD0iRJCyX9Q9JtkrYr0cbBkv4kabWyO8fd79XdlLfrTUnzJB3h1VX7emZIukbS0/lzPShpl7xur7z93fJ4J0k/zl/fEmVJYesa/v13lPRzSb+r9rGN/qHvNKbv5I//X0kj/PO20g99p/P7jqSvS/pztX+rFFcqx0u6S9LOkv5L0u+VXQH1l3SFpNu9Y5dLOk5SL2V/6MlmdogkmdmxksZLOkrZH+dT0fNc
p+wf+uC8vr+ky0q06UpJj+RtGiDp3726ufk5dpH0C0lTzGzbGl+PJJ0u6WuS+knaIOnmEm36WV4/WNLHJR2trAPLzAaa2SozG1jisTKzw81stbJON0rSjaWObSH0nU7oO5L+TdKtkl4rc0yroe/Uv++MkPSqmT2cD33NMLOPlTh2kwSfGP7g1R0vaa3yTCipp7Ks2bvEuR6QdHZe/omka6Ks7/L/mqR1kvb26g+VNL/EeX8u6Q5JAyp4PSslHVTL61H2ieFa7/j9Ja2XtLW8TwySdpf0jrxPOJL+VdLjNfz798/bOaSenwzr8UPf6fy+I2mopOfycxXnbXRfoO+0RN95RNK7kj4nqbuk85Vd6XUv97gUVyr/8MpvS2pzzr3nxVI2ZCMz+5yZzTGzFWa2StJISX3yY/opu8zbyC/vJml7SfPyzLpK0vT895tzgbIO8XS+auRrGyvMbIKZ/T2f9F6l7PKwj/fYil/PZtq5QNI20fkkaVD++2Ve+2+X1LdE+0tyzi1R9trvq/axTYi+s0nyvmNmW0m6Rdkb6Ib2jm8x9J1N6vW+87akJ51zDzvn1isbBtxV0n7lHtRpE3dm1kPSfyq7bHvQOfeumT2g7I8gZWN+A7yH7OmV25S9wAPyN9WynHOvSTozf97DJT1qZk9I2kPShZL+RdLzzrn3zWyl14Za+O0cqCyzt0W/X6TsE0OfRP9zd5O0d4LztAT6Ts19p5eyK5X7zUzKPslK0mIzG+2c+3NNrW4h9J0Ove/8VdJh1T6oM1d/dVc26fW6pA1m9jll43sb/VLSODPbz8y2lzdu6Zx7X9KPlI2F9pUkM+tvZsds7onMbLSZbewoK5VdDr6n7DJyQ96GbmZ2mbL/8TriVDPbP2/zFZKmep8wNrZ/mbJLyRvMrJeZbWVme5tZPH67WWZ2Sj7+aWY2SNLVkh7rYLtbCX2ntr6zWtkn8YPzn5H57z8h6akOtr1V0HdqfN9RtnBghJkdZWZbSzpHWeL6e7kHdVpScc6tkXSWsj/iSkljJf3aq39Y2WTT45JekjQ7r3on/++F+e/nmNmbkh6VtG+Jpxsm6SkzW5s/x9nOufnKJr8eVrYSZoGkfyq8jKzFXZLuVDYJum3+GjfndGUd/H+Uvf6pyj7BbJwwW2ulJ8z2lzRL2TjrTEkvKv9EtCWg79TWd1zmtY0/yt7UJOkf+XBGl0ffqf19xzn3oqRTla14WynpREkntNd3LJ+QaTpmtp+kv0nq0azjwWY2Q9Ldzrn/aHRbsAl9B7Wi73RcU3350cy+aGbdzWxnZUv5ftOsf1g0F/oOakXfSaupkoqkbyi7RH9Z2VjktxrbHLQQ+g5qRd9JqGmHvwAArafZrlQAAC2MpAIASKaqLz+aGWNlTcg515EvUdUd/aZptTnnSn07vCnQd5pWyb7DlQqw5VrQ6AagZZXsOyQVAEAyJBUAQDIkFQBAMiQVAEAyJBUAQDIkFQBAMiQVAEAyJBUAQDKddjthAMDmjRkzJojHjx8fxEOHDi3Kp59+elB3zz331K9hNeBKBQCQDEkFAJAMSQUAkAxzKgDQABMnTizKF198cVC33XbbBbF/M8WTTz45qGNOBQDQZZFUAADJtPTw18EHHxzEv//974O4T58+RXmrrcL8+dBDDwXxpZdeWpSfe+65NA1EUmeffXYQX3HFFUX5lFNOCeriv28jxO0dOHBgUZ4wYUJnNwedzP97S9LVV18dxGPHji3K/vBWe95+++2ONazOuFIBACRDUgEAJENSAQAkY9WM5ZlZ5Qd3gunTpwfxUUcdVfLYtra2IPbnWyRp2bJlRfmqq64K6m6//fZam9gpnHPW6DaUk6rfTJ06NYi/8IUvFOUlS5YEdYMGDUrxlFUZMmRIED/55JNBvMsuuxTlbt2aYjpznnNuaPuHNU6zvedU47zzzgvia6+9NojNNv1vW8378Ic//OEgXrRoUQ2t67CSfYcrFQBAMiQVAEAyTXENXqsF
CxZUfOyDDz4YxPPmzQviW265pSgfeuihQV2zD39tKSZNmhTExxxzTFHu27dvUPfNb34ziG+77bb6NSznD29tLr733nvr3gY01iWXXFKU/W/MV2vWrFlB7C+Z94fqmxFXKgCAZEgqAIBkSCoAgGRaeklxz549g/jOO+8M4hNPPLEov/fee0HdK6+8EsT77LNPUb777ruDuq9+9asdaGX9bSlLimPTpk0ryieccEJQ9+qrrwbx4YcfXpTrNSY9YsSIII6XFK9du7Yo9+7duy5tqBJLihNbsWJFUe7Vq1fZY9etW1eUn3rqqaAufs9ZunRpxxuXFkuKAQD1R1IBACRDUgEAJNPS31NZs2ZNEJ966qlBfOSRRxblX//610GdP4eC1vTEE08UZX/LFknaa6+9gtjftqWz1vn723BIH5wDROs79thjg7hHjx4VP9afu/3Od76TrE2NxpUKACAZkgoAIJmWHv6KxXdE++1vf1uU+/fvH9TFu9r63nrrrbQNQ134w1jvv/9+UBcvld9jjz06pU3l2oDW98UvfjGI468flBv+mjt3bhB31bt/cqUCAEiGpAIASIakAgBIpkvNqZQT3+mx3Hi3vw0+mtd9991XlO+5556yx/q3L5g5c2ZQt3z58rQNQ5c1bty4IC43h/Lyyy8HcTwf889//rMo77DDDkFdfHfHcjZs2FCUX3jhhYofVy9cqQAAkiGpAACSIakAAJLZYuZULrvssrL1zz77bFGu5jbFaA7f/va3g/iHP/xhEPu39o2Pvfzyy+vWrlLOPPPMIP7Rj37U6W1AZT760Y9uttye+Ptur732WhD7ffRDH/pQUOfftqM9/nZVY8aMCeoeeeSRis+TClcqAIBkSCoAgGS69PDXV77ylaI8atSooC5eUjx58uSiHO9+jOb3m9/8JojjIa2+ffsW5UsvvTSoe+aZZ4L4oYceStKmeJdi34477pjkOVB//u7n/m7X7TnwwAODOL777FZbbfpMH28zVA3/DpMPP/xwULfzzjsH8Ztvvlnz81SKKxUAQDIkFQBAMiQVAEAyXWpO5YADDgjia665puLHxncKRGtZunRpEMdLP5988smiHN/189577w3i008/vShPmzat4jasXbs2iOO5OeZRWpO/TDfl7Qz8eZR63Sbh4osvDuKLLrqoLs/j40oFAJAMSQUAkAxJBQCQTJeaU9l///2D2N9Ouq2tLaiLt8K/6qqrivLBBx8c1I0dOzaI4/XmaD5vvPFGEE+dOrUon3POOUHd9ttvH8Q//elPi/Kee+4Z1MXbadx5551F+W9/+1tQ9/zzzwfx8OHDi3K8nYb/PSl0rvjvf9xxxwXxrrvuWpfnXb9+fVGOt3Tx+6AUfpfqpptuCurKtS/+nkpn4EoFAJAMSQUAkIxVs5TNzOqz7q1O/OGw+fPnB3Xxti3nn39+UY6XJk+cODGIr7/++lRNTMI5V3o/kCbQbP3m85//fBDHW7p8/OMfr/hc/rYX8RLi3XbbLYi7d+9elJcsWRLUVbP9R0LznHNDG/HElapX3/GHNR944IGg7qCDDkryHPEdRX/84x8HsT+k9eCDD1Z83vi9LB6i9fm7c0tJt2kp2Xe4UgEAJENSAQAkQ1IBACTTpedUquHfie/WW28te2y3bs21Eps5lY7p2bNnEPtb448bNy6o8+dF4sdW8/9SPKcycODAih+b0BY7p3L44YcX5RkzZiQ7r//VhNtuuy2oi+/8WI1vfetbRTlefl7u/Yg5FQBASyOpAACSIakAAJJprsmBBpo9e3ZRXrVqVVDXu3fvzm0MOlX8/ZILLrigKN98881BXbwlRryNvu/KK68seWy9tjpHY/m3kF65cmXZY/35uHg7lXgu7+STTy7K5eZQ4veujtymuFZcqQAAkiGpAACSYfgr5+8wu2jRoqCO4a8t1+LFi8vGf/nLX0o+Nt4Oxh/+6tu3b1D3mc98pig//vjjVbcTzeF73/teUY7vJhsvI/e3
A/rsZz9b83P6Q17xDsvx3Ug7A1cqAIBkSCoAgGRIKgCAZJhTyfl3e+zXr19Q5y8TBCr1u9/9LohPO+20ohxv91KvuwuicU455ZSy9f77SkeWmD/22GNFec6cOTWfJxWuVAAAyZBUAADJbDG7FG+zzTZB7N8VUpJuueWWojx8+PCg7vXXXw/iPfbYI3HrOoZdilvDwoULi/KAAQOCOn9Hh8MOO6yzmrTF7lLcq1evojxs2LCgbtKkSUHsf/M9XibcEeWGv/y+IkmrV68uypdccklQN2vWrKLc3rf4E2KXYgBA/ZFUAADJkFQAAMk03ZLieMwyHu+cMmVKxef69Kc/XZRHjhwZ1I0fP77k45YtWxbExx9/fMXPCZSydOnSoty/f/+grhG7yW7J/Dsg+ktypXD7FEkaNGhQUT722GODuhNPPDGIjz766Jra89JLLwVxvN1KXN/MuFIBACRDUgEAJENSAQAk03RzKieddFIQn3vuuUF85plnFuUnnngiqBs1alQQ77fffkU5/p5KvC7cn0eJx0mfe+65dloNtO/+++8vyvFcIZrXggULivLtt98e1MUxuFIBACREUgEAJNN0w1+33nprEI8ePTqIjzzyyM2W2/OnP/0piOfOnRvE9957b1FmuAv18OyzzxbluD9Omzats5sD1AVXKgCAZEgqAIBkSCoAgGS2mK3vuzK2vkeNttit79FhbH0PAKg/kgoAIBmSCgAgGZIKACAZkgoAIBmSCgAgGZIKACAZkgoAIBmSCgAgGZIKACCZare+b5O0oN2j0JkGNboBFaDfNCf6DmpVsu9UtfcXAADlMPwFAEiGpAIASIakAgBIhqQCAEiGpAIASIakAgBIhqQCAEiGpAIASIakAgBI5v8DHnS/ZV+eu3YAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 6 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from torch import optim\n",
    "\n",
    "import torchvision\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# plot_image, plot_curve and one_hot are defined in the previous cell.\n",
    "\n",
    "batch_size = 512\n",
    "\n",
    "# Pixel mean / std used to standardize MNIST (the values commonly quoted for its training set).\n",
    "MNIST_MEAN, MNIST_STD = 0.1307, 0.3081\n",
    "\n",
    "# Preprocessing shared by the train and test splits:\n",
    "# ToTensor converts each image to a float tensor scaled to [0, 1];\n",
    "# Normalize then standardizes it as (pixel - mean) / std, giving roughly zero mean and\n",
    "# unit variance, so values are no longer confined to [0, 1] (see the printed min/max).\n",
    "mnist_transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),\n",
    "])\n",
    "\n",
    "# step1: load dataset (downloaded into ./mnist_data if not already present)\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    torchvision.datasets.MNIST('mnist_data', train=True, download=True, transform=mnist_transform),\n",
    "    # batch_size images per batch; reshuffle the training set every epoch\n",
    "    batch_size=batch_size, shuffle=True\n",
    ")\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    torchvision.datasets.MNIST('mnist_data', train=False, download=True, transform=mnist_transform),\n",
    "    # keep the test set in a fixed order\n",
    "    batch_size=batch_size, shuffle=False\n",
    ")\n",
    "\n",
    "# Sanity-check one training batch: shapes and the value range after standardization.\n",
    "x, y = next(iter(train_loader))\n",
    "print(x.shape, y.shape, x.min(), x.max())\n",
    "plot_image(x, y, 'image sample')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8eb810d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义网络\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        \n",
    "        # xw + b（线性层）\n",
    "        # 256由经验决定\n",
    "        self.fc1 = nn.Linear(28 * 28, 256)\n",
    "        # 256必须和上面匹配\n",
    "        self.fc2 = nn.Linear(256, 64)\n",
    "        # 64必须和上面匹配，10为10分类\n",
    "        self.fc3 = nn.Linear(64, 10)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        # x: [b, 1, 28, 28]\n",
    "        # relu：非线性单元\n",
    "        # h1 = relu(xw1 + b1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        # h2 = relu(h1w2 + b2)\n",
    "        x = F.relu(self.fc2(x))\n",
    "        # h3 = h2w3 + b3\n",
    "        x = self.fc3(x)\n",
    "        \n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "061bb754",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0 0.12842699885368347\n",
      "0 10 0.10385735332965851\n",
      "0 20 0.08817404508590698\n",
      "0 30 0.08133535832166672\n",
      "0 40 0.07538099586963654\n",
      "0 50 0.07138819992542267\n",
      "0 60 0.06804148107767105\n",
      "0 70 0.06327078491449356\n",
      "0 80 0.060865677893161774\n",
      "0 90 0.056347720324993134\n",
      "0 100 0.05523949861526489\n",
      "0 110 0.054405249655246735\n",
      "1 0 0.05184739828109741\n",
      "1 10 0.04989299923181534\n",
      "1 20 0.04891267418861389\n",
      "1 30 0.0497170016169548\n",
      "1 40 0.0476524718105793\n",
      "1 50 0.046154819428920746\n",
      "1 60 0.04395866394042969\n",
      "1 70 0.04423324391245842\n",
      "1 80 0.04306516796350479\n",
      "1 90 0.03973644971847534\n",
      "1 100 0.04161714389920235\n",
      "1 110 0.041153814643621445\n",
      "2 0 0.03949680179357529\n",
      "2 10 0.039193861186504364\n",
      "2 20 0.036691345274448395\n",
      "2 30 0.0395888015627861\n",
      "2 40 0.03496488183736801\n",
      "2 50 0.03639781475067139\n",
      "2 60 0.03534800931811333\n",
      "2 70 0.038655273616313934\n",
      "2 80 0.03573676198720932\n",
      "2 90 0.03561261296272278\n",
      "2 100 0.034412652254104614\n",
      "2 110 0.03432656452059746\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEGCAYAAAB/+QKOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAsDUlEQVR4nO3deXhU9dn/8fdNAEFRQYgKBAlLBEH2VGnpo487AgpWRaoVq7a4oUKrVWupUvf+1AdR0VpFq6UgRSpY17rVpYgGxQURDXsEJWBBFJDt+/vjnnFCmIQBMjmTzOd1XXOdM+ecmbnnKHPnu1sIARERkfLqRB2AiIhkJiUIERFJSglCRESSUoIQEZGklCBERCSpulEHUJWaNWsW8vPzow5DRKTGmDVr1soQQm6yc7UqQeTn51NUVBR1GCIiNYaZLa7onKqYREQkKSUIERFJSglCRESSqlVtECIiO7Jp0yZKSkrYsGFD1KFUqwYNGpCXl0e9evVSfo0ShIhklZKSEvbee2/y8/Mxs6jDqRYhBFatWkVJSQlt2rRJ+XWqYhKRrLJhwwaaNm2aNckBwMxo2rTpTpealCBEJOtkU3KI25XvnPUJYssWuPlmeOGFqCMREcksWZ8gcnLg9tvhySejjkREZHuNGjWK7LOzPkEAtGsH8+dHHYWISGZRgkAJQkSqz1VXXcW4ceO+f3799dczevRojjnmGHr27EmXLl2YNm3adq979dVXGTBgwPfPhw8fziOPPALArFmzOPLII+nVqxcnnHACy5cvr5JY1c0VTxBPPAGbNsFOdBEWkRpuxAiYPbtq37N7dxgzpuLzQ4YMYcSIEVx88cUATJ48meeee46RI0eyzz77sHLlSnr37s3JJ5+cUsPypk2buPTSS5k2bRq5ubk8/vjjXHvttYwfP363v4sSBJ4gNm+GJUt8X0QkXXr06MGKFStYtmwZpaWlNGnShObNmzNy5Ehee+016tSpw+eff86XX37JgQceuMP3mzdvHh999BHHHXccAFu2bKF58+ZVEqsSBImkMH++EoRINqnsL/10Ou2005gyZQpffPEFQ4YMYcKECZSWljJr1izq1atHfn7+dmMW6taty9atW79/Hj8fQqBz587MmDGjyuNUGwTbJggRkXQbMmQIkyZNYsqUKZx22mmsWbOG/fffn3r16vHKK6+wePH2M3C3bt2ajz/+mO+++441a9bw0ksvAdChQwdKS0u/TxCbNm1izpw5VRKnShBAixawxx5KECJSPTp37szatWtp2bIlzZs356yzzuKkk06isLCQ7t2707Fjx+1e06pVKwYPHkzXrl0pKCigR48eANSvX58pU6Zw2WWXsWbNGjZv3syIESPo3LnzbsdpIYTdfpNMUVhYGHZ1waDOneHgg+Ef/6jioEQko8ydO5dDDjkk6jAikey7m9msEEJhsutVxRTTrh0UF0cdhYhI5lCCiGnXDhYsgFpUoBIR2S1KEDHt2sG6dfDFF1FHIiLpVpuq1lO1K99ZCSKmfXvfDh/uA+ZEpHZq0KABq1atyqokEV8PokGDBjv1OvViijnySDjjDHj8cZgzx0dDikjtk5eXR0lJCaWlpVGHUq3iK8rtjLQmCDPrC9wF5AAPhhBuLXe+I/Aw0BO4NoRwe+x4K+BR4EBgK/BACOGudMbasCFceaUniEWLlCBEaqt69ert1Kpq2SxtCcLMcoB7geOAEuAdM5seQvi4zGVfAZcBg8q9fDPw6xDCu2a2NzDLzP5V7rVVLj/ft4sWpfNTRERqhnS2QRwGFIcQFoQQNgKTgIFlLwghrAghvANsKnd8eQjh3dj+WmAu0DKNsQKw337QqBEsXJjuTxIRyXzpTBAtgaVlnpewCz/yZpYP9ABmVk1YlX0WtGmjEoSICKQ3QSSbp3anug2YWSPgCWBECOHrCq4ZZmZFZlZUFY1O+flKECIikN4EUQK0KvM8D1iW6ovNrB6eHCaE
EKZWdF0I4YEQQmEIoTA3N3eXg42LJ4gs6gEnIpJUOhPEO0CBmbUxs/rAEGB6Ki80XyXjIWBuCOHONMa4nfx8+Ppr+O9/q/NTRUQyT9p6MYUQNpvZcOB5vJvr+BDCHDO7MHb+fjM7ECgC9gG2mtkIoBPQFTgb+NDMZsfe8rchhGfSFW9cvPfbokXeaC0ikq3SOg4i9oP+TLlj95fZ/wKveirvDZK3YaRd2a6uPXtGEYGISGbQVBvlaCyEiIhTgiincWPYZx+NhRARUYIoR2MhREScEkQSGgshIqIEkZTGQoiIKEEklZ8P33wDq1ZFHYmISHSUIJIoOxZCRCRbKUEkoa6uIiJKEEm1bu1bJQgRyWZKEEk0buwPjYUQkWymBFGBNm1gwYKooxARiY4SRAU6doS5c6OOQkQkOkoQFejcGRYv9u6uIiLZSAmiAp06+ValCBHJVkoQFejc2bdz5kQbh4hIVJQgKtC2LdSvDx9/HHUkIiLRUIKoQN26cNBBsHRp1JGIiERDCaISLVrAsmVRRyEiEg0liEooQYhINlOCqEQ8QWjabxHJRkoQlWjRAtatg6+/jjoSEZHqpwRRiRYtfKtqJhHJRkoQlVCCEJFspgRRiXiCGD0aNmyINhYRkeqmBFGJVq1g773h9ddh0qSooxERqV5KEJVo0MDXpW7cGF56KepoRESqlxLEDtSrB337eoJQd1cRySZKECk46ihYvhzmz486EhGR6qMEkYIOHXy7eHG0cYiIVCcliBTk5fm2pCTaOEREqpMSRAri3V2VIEQkmyhBpKBhQ2jaFD7/POpIRESqT1oThJn1NbN5ZlZsZlcnOd/RzGaY2XdmdsXOvLa65eWpBCEi2SVtCcLMcoB7gROBTsBPzaxTucu+Ai4Dbt+F11arli2VIEQku6SzBHEYUBxCWBBC2AhMAgaWvSCEsCKE8A6waWdfW91UghCRbJPOBNESKLtgZ0nsWLpfmxYHHQSlpbByZZRRiIhUn3QmCEtyLNWxyCm/1syGmVmRmRWVlpamHNzOOvlk3/7tb2n7CBGRjJLOBFECtCrzPA9IdeLslF8bQngghFAYQijMzc3dpUBT0aULFBbCo4+m7SNERDJKOhPEO0CBmbUxs/rAEGB6Nbw2bY4+Gj78ELZujToSEZH0q5uuNw4hbDaz4cDzQA4wPoQwx8wujJ2/38wOBIqAfYCtZjYC6BRC+DrZa9MVa6ratIGNG31eppaRtoiIiKRf2hIEQAjhGeCZcsfuL7P/BV59lNJro5af79uFC5UgRKT200jqndCmjW8XLow2DhGR6qAEsRNat/btokWRhiEiUi2UIHZCgwbQvDnMnRt1JCIi6acEsZMKCmDiRDj/fNhUfvy3iEgtogSxkx58EC6/HMaPh/vv3/H1IiI1lRLETioogDFjfPrvOZF3vBURSR8liF2Un68lSEWkdlOC2EWtW6s3k4jUbkoQu6h1ay9BhFSnHxQRqWGUIHZRfj6sX6/pv0Wk9lKC2EXxQXNqhxCR2koJYhe1b+/bG2/UeAgRqZ2UIHZR584wahRMmwavvhp1NCIiVU8JYjdcfrlv338/2jhERNJBCWI3NG3q037Pnh11JCIiVU8JYjd1764ShIjUTkoQu6lbN5/d9c03o45ERKRqKUHspvPO8y6vxx8P8+dHHY2ISNVRgthN7drBa69BvXpwwgkwY0bUEYmIVA0liCrQsiVMnQrffgvXXBN1NCIiVUMJooocfTQMGQIzZ8J330UdjYjI7lOCqEJHHAEbNkBRUdSRiIjsPiWIKvTjH/v25ZejjUNEpCooQVSh3FxPEhMmaBpwEan5lCCq2Lnnwrx56s0kIjXfDhOEmR1gZg+Z2bOx553M7Pz0h1YzDR4M++4Ld90VdSQiIrsnlRLEI8DzQIvY80+BEWmKp8Zr1AguvBCmTIHzz/dFhUREaqJUEkSzEMJkYCtACGEzsCWtUdVwV17pXV7Hj4c//CHqaEREdk0qCeJb
M2sKBAAz6w2sSWtUNVzTpt5Qfc45cMcdsGQJbNwYdVQiIjsnlQTxK2A60M7M3gQeBS5Na1S1xKhRvtpc69Ze7SQiUpPsMEGEEN4FjgR+BFwAdA4hfJDuwGqDdu3g7LN9/5FHIg1FRGSn1d3RBWY2tNyhnmZGCOHRNMVUqzz8MHTpAr/5DZSW+lgJEZGaIJUqph+UefwPcD1wcipvbmZ9zWyemRWb2dVJzpuZjY2d/8DMepY5N9LM5pjZR2Y20cwapPSNMkxODhQW+v7++8MLL0Qbj4hIqlKpYrq0zOOXQA+g/o5eZ2Y5wL3AiUAn4Kdm1qncZScCBbHHMOC+2GtbApcBhSGEQ4EcYEjK3yrD9OyZ2H/ggejiEBHZGbsyknod/oO+I4cBxSGEBSGEjcAkYGC5awYCjwb3FtDYzJrHztUFGppZXWBPYNkuxJoR9t0Xbr0Vmjf3EdaahkNEaoJU2iCeItbFFU8onYDJKbx3S2BpmeclwOEpXNMyhFBkZrcDS4D1wAshhBpdOXPVVd799Ze/hDlz4NBDo45IRKRyO0wQwO1l9jcDi0MIJSm8zpIcK/+3c9JrzKwJXrpoA6wG/m5mPwsh/HW7DzEbhldPcdBBB6UQVnQGDIC6db00MXIk9OoVdUQiIhVLpQ3i32Ueb6aYHMBLA63KPM9j+2qiiq45FlgYQigNIWwCpuLdbJPF90AIoTCEUJib4V2EDjzQ52qaMMEbrt95R9VNIpK5KkwQZrbWzL5O8lhrZl+n8N7vAAVm1sbM6uONzNPLXTMdGBrrzdQbWBNCWI5XLfU2sz3NzIBjgLm79A0zzPXXwymn+P5FF3nPpt//PtKQRESSspDGP2HNrB8wBu+FND6EcJOZXQgQQrg/9uN/D9AXb/w+N4RQFHvtaOAMvFrrPeAXIYRKF/MsLCwMRTVkObcHH/T2iDiVJEQkCmY2K4RQmPRcqgnCzPYHvh+LEEJYUjXhVZ2alCA2bID8fPjyS3++cqU3YouIVKfKEkQq60GcbGafAQuBfwOLgGerNMIs1KABPPtsYlxEs2ZaZEhEMksq4yBuAHoDn4YQ2uDtAW+mNaos0aMH/Pzn3rsJYNy4SMMREdlGKgliUwhhFVDHzOqEEF4Buqc3rOxRrx489RQMGwZTp3pVk4hIJkglQaw2s0bA68AEM7sLbziWKnTxxbBlC/TvD0VFMH9+1BGJSLZLJUG8BjQGLgeeA+YDJ6UxpqzUrRtMmuTJ4Qc/gEGDoo5IRLJdKgnC8DWpXwUaAY/Hqpykig0aBBMn+v5HH2kVOhGJViojqUeHEDoDlwAtgH+b2YtpjyxLDR7sJQnwJCEiEpWdmc11BfAFsArYPz3hCCTmaJo6FdavjzYWEcleqYyDuMjMXgVeApoBvwwhdE13YNmsXTto1AhuugkaN4YLLoC779ZoaxGpXqnM5toaGBFCmJ3mWCTGDJ580nsyzZiRGEw3YAC0aRNpaCKSRVJpg7hayaH6HXOMj414+GF45RU/NnUqvP12tHGJSPbYlRXlpJp1jVXoXXEFHH64T9EhIpJuShA1wH77+UJDcaNHRxeLiGQPJYgaYnNs7Hr//l7N9O67arQWkfRSgqghTj/dt6NHe2Lo1QseeSTSkESkllOCqCEeewxWrYKePeG88/zYzTfD+PGwdWu0sYlI7aQEUUPssYe3RZjBQw/BdddBcTGcfz4891zU0YlIbaQEUUP94hfQt6/v33YbfPVVtPGISO2jBFFD5eV5d9cxY+C116BTJ3jvPfj6ax95vXx51BGKSE2nBFHDXX65TxFepw6MHAnTp/vI6xYtfNpwtU+IyK5SgqgFevWCs8+G//wHXn89cbyoCJYsiS4uEanZlCBqieOOg02bvPTQq1diMN3HH0cbl4jUXEoQtUSfPt7TCeDHP4ZLLvH9uXOji0lEajYliFqiYUN4+mk46ig46yxo2hT239/nb7ruuqijE5GayEItmq+hsLAwFBUVRR1Gxth7
b/jmG9/futXHUIiIlGVms0IIhcnOqQRRi919d2KSvzlzoo1FRGoeJYha7Oc/h88+8/0uXWDoUMjNhQULIg1LRGoIJYhaLj8fjj/e9x97DFau1CR/IpIaJYgs8Pzz8Mwzvt+oEdxwA5xxBvz1rz52QkQkGTVSZ5F334XVq305U/BG7D59tEKdSDZTI7UAPlX40Uf7OtcAa9fCzJkwbhyUlCSuGzMGpk2LJEQRySBKEFmoffvE/n//64PqWreGUaPgww99TqdBgyILT0QyRFoThJn1NbN5ZlZsZlcnOW9mNjZ2/gMz61nmXGMzm2Jmn5jZXDP7YTpjzSZlE0RcixZw443e0ymucWN45ZVqC0tEMkzaEoSZ5QD3AicCnYCfmlmncpedCBTEHsOA+8qcuwt4LoTQEegGaNKIKnLAAd5YXb++Px85EmbM8P3Zs33da4A1a+DCC2HLlkjCFJGIpbMEcRhQHEJYEELYCEwCBpa7ZiDwaHBvAY3NrLmZ7QMcATwEEELYGEJYncZYs4qZlyLat/eR1nfc4etLHHywn7/0Um+DuPFG+PRTePnlaOMVkWjUTeN7twSWlnleAhyewjUtgc1AKfCwmXUDZgGXhxC+TV+42eW3v/WSwV57JY717+9tEkcd5aWLo47yeZyuuAK6doVHH/U2ilatoEmT6GIXkeqRzhJEspl/yvepreiaukBP4L4QQg/gW2C7NgwAMxtmZkVmVlRaWro78WaV00+HIUO2PXbzzfDRR4mqp733hm7d4IMPfMzEtGlw2GHwq19Vf7wiUv3SmSBKgFZlnucBy1K8pgQoCSHMjB2fgieM7YQQHgghFIYQCnNzc6sk8GzVoIHPAFtW2Vt6yinw3XcwdSps2FC9sYlI9UtngngHKDCzNmZWHxgCTC93zXRgaKw3U29gTQhheQjhC2CpmXWIXXcMoKVvIjB6NBx5JFx0kT9v3drXvX7qqWjjEpH0S+tIajPrB4wBcoDxIYSbzOxCgBDC/WZmwD1AX2AdcG4IoSj22u7Ag0B9YEHs3H8r+zyNpE6fEHwepz33hB49ICfHq56WLIEVK+CH6oQsUiNVNpJaU23ITps2zQfSTZwId90F77/vJYyTT/bShojUHEoQUqW2bPGeTHXqwOefJ44feii8+CKMGOELFP3tb17SEJHMpbmYpErl5MBPfuLJoXHjRPtEaamXLiZNgsmT4ZNPkr9+yxZv7BaRzKYEIbvk8st9yvC33vLJ/saOhS+/hJdeSlzz9tu+XbsWrrzSq6Juvhn22w9OOimauEUkdapikioxYwb86Ee+37UrLFwIHTp4G8Wtt3qvp9xcL2XEbdwI9epFE6+IOFUxSdp17+49nAA6dfLusEVFvt7EU095Iigthc6dfdAdwLx5kYUrIilQgpAq0bAhnHii7+fnw7XX+kjtY47x6qVTTvFzffp4MgGvnjr/fCUKkUylBCFV5mc/8+3BB3tymDjRezX98Y9weGwWrh/+0KueGjSAO++E8eNhwoToYhaRiilBSJUZNMirlc45Z/tzJ5/sSeKEE6BuXS9JzI1N4D5z5vbXg08cOHasphsXiYoShFSpXr18fER57dt7lVLz5v780ksT595+20dql/enP3lvqVdfTUuoIrIDShASiQEDvJvsoEGwejW0bAmrVm17zbPP+vaZZxLHFi3yEki8C62IpI8ShEQiJ8cH1I0b51OKL1/u606sXu1VSmvWwH/+49c+84w3ZH/2GTz+uJ9/5JEooxfJDkoQEqnmzX2Z01NP9R/9ww7zEsKpp8LmzXDssT4iu2NHb/yOt1vsvXeUUYtkByUIyQj33eeN25995s/jI7IHD972ur/8xbdl54ASkfRQgpCMkJsLDz0EBQWJY3XqeFtFXNOmif0JE6BNGx9898AD8Oc/+xKpPXvCstiyVG++CR9rFRGRXaYEIRkjJ8d/1M8+25/n5XkVVF6eP3/rLR9XMWiQP1+0CC68EC64AIYN895O770H
U6Z4r6jTT/dzO7JuXRq+jEgtoAQhGSU317vKgo/IBh95vc8+0Latj8zu0CFx/dSpvj3iCE8eHTrA00/D4sXe8P3GG17yeOGFxGsefBB+8xsvfbz4ok8eWFLiSWXr1ur4liI1Q92oAxApr21b38YTxA03wNKlifEV8b/4r7oKxozxEsa//+3H+veHe+6BJ55IvN9XX8Ett/g8UC1awGWXwfr1XhXVsaNPPf7ee3DeefDtt16KERGVICQDtWvn27IliLLTg599tldHDRvmP/JvvZU4d9FF3gvqiiugfn2fOBC8+ikvD55/3pMDwPTp8NFHvv/++/Cvf3nX2k8+gTvu8GnKRbKZEoRknIICX5CoojUjfvAD7wLbtq1XDzVrljjXvj089hgccgicey7MmQMvv5w4/7vf+TaeAP7xD3/+2GOJa37xC08wf/971X4vkZpG60FIVli2DHr39qqqxo3hiy/gwAN9YF5Z7dtDcbHvX3mlTzQ4apQnpFtuqe6oRdKvsvUg1AYhWaFFCxg40NsnrrsO9tjDn8fHVQAceaSXPOIJYuJETyB//rM/37jRe0atW+fVUBddBGbV/lVEqo1KEJI1vv0W5s/3Fe/AezsNGADDh/vxv/zFq5xS6RoL3kOqT5/0xStSHbSinAiw116J5ADQr59PCDhmjM/3lJubWMxoR+8D8Le/+Xb1au89NWvW7scYQvKZbUWioAQhWcsM+vb1HlFxXbrAAQf4Fny8xO23b/u6iy/2KUAmTfK2jV//2hPMuHGJa44+OrGKXqo2bfKG98aNfY3vVPz5z1WTmESSURuESBkNG/oAu+++85ljzz7bB89dc42PozjjDB8vsWwZTJ7s05THLV7s2xDglVcS+59+6l136+7gX9uMGT46HHwRpR/+sPLrQ/AxHWed5YP/RKqaShAi5Zj5kqjnnOOD8+rW9W63Z54JV18N++/vVVF33OFLqz73nHeNfeklP19Skniva67xwXgFBT5q+9e/3nYQX1kvvJAYDLh8+Y7jXLsWNmyAFSt2+yuLJKUShEgKJk3a/tivfpXYX7rU/4q/7TafFiTuttu8VLJ0KRx3nB+78044/ni46SafVmTFCk86Tz/tpYalSz1BfPopXHKJv/6JJ6BevW0//8svfasEIemiEoRIFejbF1q18v1419muXb0rbXGxt2XsuWdi7qgXXoA//MGXXs3L82lDZs+Gn/7UJyhctsyrt158EZ56CkaP3vbzvvgCXn/d9ytKEC+/7Ot674zXX/dSiQgAIYRa8+jVq1cQicrWrSEccID3Q2rSxJ+XPbdmje+/8UYIJ5wQ768Uwh57+DY3N4Rvvw1h0KDEuQcfDGHw4BD22SeEL78M4eKLQ+jWLXEeQthrr+1jWbjQz51+euLYqlUhrFhRcfwlJf6ae+6pirshNQVQFCr4TVUVk0gVMYMmTbzq56STth1EZ5aoeurTx0dnP/+8r18xbZrvd+nipYzmzf26vfaCoUN9PqnJk/3aZAslffutP+Ldb8HnmYLEoL85c6Cw0EePf/KJDxQsb8EC33744e7dB6k9VMUkUoUuucR/5Mt3jS2vTx94913vuZSXB+ef78utQiJB9Ojh7Q69e8OIET6S++67t00Ecc8/79OBxE2b5tv49OXXX+9VR4sWwdix3kvrT3/ybVy8F5YWWZI4JQiRKjR8uPcuys3d8bU9eviMs+XF19tu3963ZvB//+dtDcOHw49/vP1rTj3Ve1SF4OMp4lOWz5vniePll+HnP4eTT/YJC0eN8sWWRo1KvMeSJb6Nr/td3owZ3s4SbxyX2i+tCcLM+prZPDMrNrOrk5w3MxsbO/+BmfUsdz7HzN4zs3+mM06RqlRnN/9VxZdW/d//TX5+6FA4/PDtjz/xhCeETz7xksGRR3qpYdo0XxPj2GN9Wde6dX30OPhAu02bfD9egli5En77W9iyxZ8vXAjffOMN7CUlienV583za6X2SluCMLMc4F7gRKAT
8FMz61TushOBgthjGHBfufOXAxX8PSNSO511lq9NMXRo8vNnnrntGhhlFRf74kfgJQpITDZ49NE+NXq3bomkEJ+McMoUL0HEk9stt3jV1rHH+ujuLl0SvaY++MCru370I5+wMFVbt/pAw/gU65L50lmCOAwoDiEsCCFsBCYBA8tdMxB4NNaY/hbQ2MyaA5hZHtAf0BhRySp16vgP845min3rLZ9m4/TTE91nFy/2BNGwoc8PBfDaa15tdeCB/rxbN9/Gq7IuucTfo6jIG9cfftiPFxX54D/wtouZM33/gw98pPhXX/nYjXXrfMzGqFHbLtn6n//AtdcmVgB88UVvbB8+fJdvjVSzdPZiagksLfO8BChfME52TUtgOTAG+A2wd/pCFKm54tVMkyf7D3P9+j5Yb9Ei7/HUsqUfW7/epwmJJ5x4gjjppMSEg+DVRR07elvFjBnwwAOJc6ee6kmodWtPEPHR4OvXewP5uHGeAHJyvLTRvn1iptvu3T0B3X+/Pz/44OTfZ/16H8Fu5u0cK1Yk5sSSaKSzBJHs75/y81QmvcbMBgArQgg7nIbMzIaZWZGZFZWWlu5KnCI1Xp063oA8ezbsu68nijp1/AcdEltIzGhbWJjoMXXAAV5ddNVV/vyWW7yh++CD/Yd+yhSvVjr3XPjsM5gwwau69tvPE0e8vWL0aDjtNE8K8Qb4xx/3wX7xlf1KS30A3913J2auLS313l/33OPP8/O3nXl31ixv+J8/vwpvmuxQOhNECdCqzPM8YFmK1/QBTjazRXjV1NFm9tdkHxJCeCCEUBhCKMxNpeuISC0Vr0L6yU+8JACJdb0POihx3Q9+4HNCnX46dOjgx55+2ksBTZr48/328zaGp5/2ZADeuH3BBT7b7Lp13pA9cKD/+McnGYRE1dWll3pyeeIJ7z21Zo2XEJYt8yVeL7vMu/rCtku/fvddYjT3xo2+feopL+E89ZQ/79/fpyoJYdtxG3fe6Q3xO5Lq9CQheALM2vXJKxpBt7sPvPpqAdAGqA+8D3Qud01/4Fm8JNEbeDvJ+/wv8M9UPlMjqSWbtW/vI6GffDJxbNgwP3bzzclfc+GFIdStG8L69al/zr//HcKf/uT7Tz2VGNF97LEh3H23jxpfvDiEjRtDOO20bUd9n3GGb3/2M9+OHRvCM8+E0LOnP+/TJ4Tnnktcv3Chf84xx/jz/v1DWLs2cf6Pf/TtCy/4dfHjy5aFsG5d8vjHjfNrPvlkx9/1ySf92lGjkp+v6DNqEioZSZ3WqS+AfsCnwHzg2tixC4ELY/uG93SaD3wIFCZ5DyUIkRQMHer/oletShy7+WY/NmFC8tcsXBjCP/+565/53XeJH+VHH93+/NtvbzutyIMP+jY+JUn8kZPjj8aNQxg4MHH8jTc8gUAIZj6tyLPPJs7HpykZMCCE++9PHB84MIR69UKYMcPjWLw4hC++8P14Ir3vvhBef90T2pVXhjBrVgiLFoVwyy2JaVKGDPFrr712++/2wQfbfkZNFVmCqO6HEoRks2+/DWHOnG2PTZzo/8pffz19n/vww/4Zn35a8TUjRvicUi++uG1igBD69Qvhq6+8NBE/duqpvr3hhsSx4cN927Xr9u9R0ePqq0O45hrfr1s3cT/KPubP923btiEcdVTidffc4wkJQhg5cvvvdOutiRhrMiUIkSy1bp3/0G3enP7PScXcudv/QMeT10svJY4tWODbfff17axZ/h3Klzxg+2qs449P7JuFUKeOJ6eWLbef6LBsNVWHDomqrvKPs89OfIf1673k1Ldv+L70UtY334Twy1+G8PnnVXFn06+yBKGpNkRqsYYNfZxD2WVV0/U5qWjb1rvn5uT4WhmDB3tjOMD//A/88Y++FkZ+vr/nmjU+BXrPnv6aG27whu8rr/TX1KnjA/2eey7xGfEBguA/702aeBfbTp3g/fe3j2n8eN82aVLxfSrbQbJjR2/of+MNf/7229uuI/7SSx7TxImJ
AYmVWbDAu/juirKfmxYVZY6a+FAJQqRmSKVEE28rSFY99uqr4ftqobg99wzfV3XFq6RGjAhh6VI//4tfJEoETZtuX0po3z758S5dQigs9Pco20AOIZx0km/vuCMRx7XXbvs5xcU+zfo333h13JVXJjoFrFrlcZ933vbf8fbbQ5g5s+L7M3Omf8aMGdtOLb+zUBWTiNQ0/ft7lU+yH7/iYv/1OuGExLHS0hCWLPH9qVNDWL1629fceKO/pnHjEI44YvtEEH/st19iv0GDEM46K4T8fH+Pl19OnGvXzhNG//7+vKTErzn22G3f7/DDfdu5c+LYM8/4tWPGhO8b6T/7zBu+b7rJY4QQDjmk4vsTrxqDEHr33rV7HELlCUJVTCKSkR57zOekSjblSMuWvi0oSBxr1iyxqt8pp/iAwbLigwXbt08MEBwwwKdSv/XWxHUPPeTra5hBu3a+HOzKlfDkkz6fVfz9b70VGjVKvPbxx31uq3fe8fEe4PNqvf2278+Zk/iMuXN9rMfdd3uVVU6Of5euXX16ktWr/bry3+GrrxLnvvrKt23b+tK1qVRn7SwtGCQiGSk+aC+ZBg28XeHII1N/v3iCaNcukSDuuMNHi0+enLiuRw8f1d2uHRxyiCeeb76Byy/384cdlpj7Cnwak4YNffDh2LHebvLII4npRs45x+epmjnTr9tzT08W113nI8Off95n3B03LvGeder4WiDxJAA+Ev2YY3wSxdtu8/fIz0/v6HIlCBGpkS64YOeuL1uCKCjwqTviI8zjo9Dr1k1cN326rwL49NP+fMkSX2TpzDO3fV8zuOYa+P3vfbJEM+jXL7EmyHHHwddf+xQkhx7qCz7FG8aHDYPjj/c5p4qLfa1y8FLJl1/Cvfd6JdLmzb6oFPgkjfHEmGxtkKqkKiYRyQqtWsHVV/sP/NChvrZFvCqoWzfvKfXKK4nrDznEq7IaN04cGzrUq5XKGzXKZ7QFL2GUn/Wne/fE5zRr5vt9+iTmnmre3EsS++/vzwsKPHlt2ODVW48+6tOZTJq0/XdKJyUIEckKZj4JYadOvl92Nb999/UJAZP9Rd6jh1d3Pf98IqEkU1Dg80/9+tfbn2vb1tsjzjzT56UCryKrV2/b6w45JPFe8dJNcTHceKN3rR082BdwGjzYz+Xlpfbdd5WqmEREKlFQsG1bQGXGjk1+3Az+GptuNASfUDHZ2uKdOvn6HW3bJtYLv+oqLz3ce6+/T3ym28mTt11TPB1UghARqUZmyZMDwMiR3nurYUNvJN9zT1/Jr18/OPHExHX9+vk23qsqXVSCEBHJEAUFia67++7ra3I89JBPrV62u2+PHt7wHZ9aPV2UIEREMlT37j5WIpl0JwdQFZOIiFRACUJERJJSghARkaSUIEREJCklCBERSUoJQkREklKCEBGRpJQgREQkKfMFhWoHMysFFu/iy5sBK6swnHRSrOmhWNNDsaZPVcTbOoSQm+xErUoQu8PMikIIhVHHkQrFmh6KNT0Ua/qkO15VMYmISFJKECIikpQSRMIDUQewExRreijW9FCs6ZPWeNUGISIiSakEISIiSSlBiIhIUlmfIMysr5nNM7NiM7s66njKM7NFZvahmc02s6LYsf3M7F9m9lls2yTC+Mab2Qoz+6jMsQrjM7NrYvd6npmdkAGxXm9mn8fu72wz6xd1rGbWysxeMbO5ZjbHzC6PHc/U+1pRvJl4bxuY2dtm9n4s1tGx4xl3byuJtfruawghax9ADjAfaAvUB94HOkUdV7kYFwHNyh37I3B1bP9q4LYI4zsC6Al8tKP4gE6xe7wH0CZ273MijvV64Iok10YWK9Ac6Bnb3xv4NBZPpt7XiuLNxHtrQKPYfj1gJtA7E+9tJbFW233N9hLEYUBxCGFBCGEjMAkYGHFMqRgI/CW2/xdgUFSBhBBeA74qd7ii+AYCk0II34UQFgLF+H+DalFBrBWJLNYQwvIQwrux/bXAXKAlmXtfK4q3
IlHe2xBC+Cb2tF7sEcjAe1tJrBWp8lizPUG0BJaWeV5C5f9jRyEAL5jZLDMbFjt2QAhhOfg/TmD/yKJLrqL4MvV+DzezD2JVUPGqhYyI1czygR74X48Zf1/LxQsZeG/NLMfMZgMrgH+FEDL23lYQK1TTfc32BGFJjmVav98+IYSewInAJWZ2RNQB7YZMvN/3Ae2A7sBy4I7Y8chjNbNGwBPAiBDC15VdmuRYtd/XJPFm5L0NIWwJIXQH8oDDzOzQSi7PxFir7b5me4IoAVqVeZ4HLIsolqRCCMti2xXAP/Ai45dm1hwgtl0RXYRJVRRfxt3vEMKXsX+EW4E/kyiSRxqrmdXDf2wnhBCmxg5n7H1NFm+m3tu4EMJq4FWgLxl8b2HbWKvzvmZ7gngHKDCzNmZWHxgCTI84pu+Z2V5mtnd8Hzge+AiP8ZzYZecA06KJsEIVxTcdGGJme5hZG6AAeDuC+L4X/1GIOQW/vxBhrGZmwEPA3BDCnWVOZeR9rSjeDL23uWbWOLbfEDgW+IQMvLcVxVqt97U6WuMz+QH0w3tdzAeujTqecrG1xXslvA/MiccHNAVeAj6LbfeLMMaJeDF3E/4XzPmVxQdcG7vX84ATMyDWx4APgQ9i/8CaRx0r8GO8auADYHbs0S+D72tF8Wbive0KvBeL6SPg97HjGXdvK4m12u6rptoQEZGksr2KSUREKqAEISIiSSlBiIhIUkoQIiKSlBKEiIgkpQQhUsXMbISZ7Rl1HCK7S91cRaqYmS0CCkMIK6OORWR31I06AJGaLDbCfTI+rUEO8HegBfCKma0MIRxlZscDo/FpmOcD54YQvoklkseBo2Jvd2YIobi6v4NIRVTFJLJ7+gLLQgjdQgiHAmPw+W+OiiWHZsDvgGODT7pYBPyqzOu/DiEcBtwTe61IxlCCENk9HwLHmtltZvY/IYQ15c73xhdyeTM2bfM5QOsy5yeW2f4w3cGK7AxVMYnshhDCp2bWC5976BYze6HcJYbP4//Tit6ign2RyKkEIbIbzKwFsC6E8FfgdnxJ07X40psAbwF9zKx97Po9zezgMm9xRpntjOqJWiQ1KkGI7J4uwP8zs634LLEX4VVFz5rZ8lg7xM+BiWa2R+w1v8NnEAbYw8xm4n+sVVTKEImEurmKRETdYSXTqYpJRESSUglCRESSUglCRESSUoIQEZGklCBERCQpJQgREUlKCUJERJL6/8RL/V+ijww7AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Initialize the network and optimizer.\n",
    "# Seeding torch's global RNG makes weight initialization (and any train_loader\n",
    "# shuffling that draws from torch's default generator) reproducible across runs.\n",
    "SEED = 0\n",
    "N_EPOCHS = 3\n",
    "LR = 0.01\n",
    "MOMENTUM = 0.9\n",
    "LOG_EVERY = 10  # print the loss every LOG_EVERY batches\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "net = Net()\n",
    "# Trainable parameters, in the notation of the formulas above: [w1, b1, w2, b2, w3, b3]\n",
    "optimizer = optim.SGD(net.parameters(), lr=LR, momentum=MOMENTUM)\n",
    "\n",
    "# Per-batch training loss, plotted after training.\n",
    "train_loss = []\n",
    "\n",
    "# Train: forward pass, loss, backpropagation, parameter update.\n",
    "for epoch in range(N_EPOCHS):\n",
    "    # One full pass over the training set.\n",
    "    for batch_idx, (x, y) in enumerate(train_loader):\n",
    "        # x: [b, 1, 28, 28], y: [b]  (b = batch size; the final batch can be\n",
    "        # smaller when the dataset size is not a multiple of b)\n",
    "        # Flatten each image: [b, 1, 28, 28] => [b, 784]\n",
    "        x = x.view(x.size(0), 28 * 28)\n",
    "        # => [b, 10]\n",
    "        out = net(x)\n",
    "        # Integer labels => one-hot targets: [b] => [b, 10]\n",
    "        y_onehot = one_hot(y)\n",
    "        # Mean-squared error between the network output and the one-hot targets.\n",
    "        loss = F.mse_loss(out, y_onehot)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        # SGD with momentum: v = momentum * v + grad;  w' = w - lr * v\n",
    "        optimizer.step()\n",
    "        \n",
    "        # .item() converts the 0-d loss tensor to a plain Python number.\n",
    "        loss_value = loss.item()\n",
    "        train_loss.append(loss_value)\n",
    "        \n",
    "        if batch_idx % LOG_EVERY == 0:\n",
    "            print(epoch, batch_idx, loss_value)\n",
    "\n",
    "plot_curve(train_loss)\n",
    "# net now holds the trained parameters [w1, b1, w2, b2, w3, b3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "51ba3c0e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test acc: 0.8802\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZUAAAELCAYAAAARNxsIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAYhUlEQVR4nO3dfZAU1bnH8d+BICDeKBJUlAgBb1AkvAmKBONLUF4ERAGhIIloRTQl0SoUEsVc38sEb5mYRCBJFRfFxOQKaoEgwk3xouVLCVdREeIFw4vlkrAiyIJEWM79Y8e2T8sMM7Nn+mX3+6ma8jx7erqfXY7z7OnuPW2stQIAwIcmSScAAGg4KCoAAG8oKgAAbygqAABvKCoAAG8oKgAAbygqAABvMlVUjDFbjDEDPexnojHmpRLfs94YUxN6HTLGLKpvLohHwmPnP40x/2eM2WuM2WiM+UF980B8Eh47VxtjXjbG7DfGrKxvDnH4StIJZIW19uzP28YYI2mzpKeSywgZsk/ScEnvSeoraakxZpO19uVk00IG7JL0K0lnSrok2VSKZK3NxEvSPEmHJX0qqUbStNzX+0l6WdJuSeskXRR6z0RJ70vaK+nvkiZIOkvSAUm1uf3sLiOXC3PvbZX0z4VXtsZObt8LJd2a9M+FV3bGjqQfSlqZ9M+jqFyTTqDEH+wWSQND8WmSPpI0VHWn8i7NxW0ltZL0iaQuuW3bSTo79I/+UmTf4yW9VWQecyTNTfrnwSuTY6elpCpJg5P+mfDKztjJUlHJ1DWVI/iepCXW2iXW2sPW2uWS1qjuH1uq+w2jmzGmpbW2ylq7Pt+OrLV/stZ2P9oBjTHHShotaW7900eCYh87ObNV95vtC/VJHolKauxkQtaLSgdJY4wxuz9/SRogqZ21dp+ksZJulFRljFlsjDnTwzGvUt15zlUe9oXkxD52jDEPSeom6Wqb+/UTmZTE505mZK2oRP9H3C5pnrX2hNCrlbX255JkrX3BWnup6qagGyX9Ic9+SnGNpMf5UMicRMeOMeYeSUMkXWat/aS8bwEJScPnTmZkraj8Q1KnUPyEpOHGmEHGmKbGmBbGmIuMMe2NMScbY0YYY1pJ+pfqLo7VhvbT3hhzTCkHN8a0l3SxpMfq/60gZomNHWPM7ao7d36ptfYjP98OYpTk2GlqjGmhujt1m+SO1czPt1UhSV/UKeUl6QpJ21R3x8Vtua+dp7pTUbsk7ZS0WNLpqvstYZWkPbntV0rqmnvPMbntdkmqzn1tgqT1Rzn+7ZJeTPrnwCtbY0d1v6F+/gHz+euOpH8mvDIxdibmxk/4NTfpn0mhl8klDgBAvWXt9BcAIMUoKgAAbygqAABvKCoAAG8oKgAAb0papdgYw61iKWStNUnnUAjjJrWqrbVtk06iEMZOauUdO8xUgMZra9IJILPyjh2KCgDAG4oKAMAbigoAwBuKCgDAG4oKAMAbigoAwBuKCgDAG4oKAMCbkv6iHmgIbrvtNidu2bKlE3fv3j1ojx49uuC+Zs2aFbRfeeUVp2/evHnlpghkFjMVAIA3FBUAgDcUFQCANyU9o54VQ9OJVYqP7i9/+UvQPtp1knJt3rzZiQcOHOjE27Ztq8hx62GttbZP0kkUkoaxE4dvfvObTrxx40YnvuWWW4L2b37zm1hyOoq8Y4eZCgDAG4oKAMAbbilGgxQ+3SWVdsorfOrhhRdecPo6derkxMOHDw/anTt3dvomTJjgxA8++GDROaBx6dWrlxMfPnzYiT/44IM406kXZioAAG8oKgAAbygqAABvuKaCBqFPH/fuxiuvvDLvtuvXr3fiESNGOHF1dXXQrqmpcfqOOeYYJ3711VeDdo8ePZy+Nm3aFMgY+ELPnj2deN++fU78zDPPxJhN/TBTAQB4Q1EBAHiTitNf4ds9r7/+eqfvww8/dOIDBw4E7T/+8Y9O344dO5x406ZNvlJEyrVr186JjXEX
GQif8ho0aJDTV1VVVfRxbr31Vifu2rVr3m0XL15c9H7R+HTr1i1oT5482enL8grXzFQAAN5QVAAA3lBUAADepOKayowZM4J2x44di37fDTfc4MR79+514uito3EIL6cQ/r4kac2aNXGn02gsWrTIic844wwnDo+NXbt2lX2ccePGOXGzZs3K3hcatzPPPDNot2rVyumLLjOUJcxUAADeUFQAAN5QVAAA3qTimkr4b1O6d+/u9G3YsMGJzzrrrKDdu3dvp++iiy5y4n79+gXt7du3O31f//rXi87v0KFDTrxz586gHf37iLDok/64phKfrVu3etnP1KlTnTj6hL6w1157rWAMhE2bNi1oR8drlj8rmKkAALyhqAAAvEnF6a+//vWvR2wfydKlS/P2tW7d2onDK3+uXbvW6evbt2/R+YWXhpGk9957L2hHT8+deOKJQXvz5s1FHwPpMWzYsKB97733On3RVYr/+c9/Bu3bb7/d6du/f38FskNWRf9cIryydvgzRfryKsVZwkwFAOANRQUA4A1FBQDgTSquqfjy8ccfO/GKFSvybnu0azeFjBo1KmhHr+O8/fbbQTvLSy00ZuFz3dFrKFHhf+NVq1ZVLCdk34UXXpi3L/xnClnHTAUA4A1FBQDgDUUFAOBNg7qmUiknnXSSE8+cOTNoN2ni1uXw3zXUZ4l1xOfZZ5914ssuuyzvto8//rgT33nnnZVICQ3Qt771rbx90cdkZBkzFQCANxQVAIA3nP4qwk033eTEbdu2DdrR25j/9re/xZITyhddWbp///5O3Lx586BdXV3t9N1///1OXFNT4zk7NBThVdIl6dprr3XiN954I2gvX748lpziwEwFAOANRQUA4A1FBQDgDddUjuDb3/62E//0pz/Nu+3IkSOd+J133qlESvBowYIFTtymTZu82z7xxBNOzOMMUKyBAwc6cfixGJL7GI/o4zWyjJkKAMAbigoAwBuKCgDAG66pHMHQoUOduFmzZk4cXjb/lVdeiSUn1M+IESOCdu/evQtuu3LlyqB91113VSolNHA9evRwYmutE8+fPz/OdGLDTAUA4A1FBQDgDae/clq2bBm0Bw8e7PR99tlnThw+JXLw4MHKJoayRG8TvuOOO4J29HRm1Jtvvhm0WYYFpTjllFOC9gUXXOD0RZdweuaZZ2LJKW7MVAAA3lBUAADeUFQAAN5wTSVn6tSpQbtXr15OX3g5BUl6+eWXY8kJ5bv11luduG/fvnm3jT75kduIUa6JEycG7egTY59//vmYs0kGMxUAgDcUFQCANxQVAIA3jfaayuWXX+7EP/vZz4L2J5984vTde++9seQEf6ZMmVL0tpMnT3Zi/jYF5erQoUPevuijxxsqZioAAG8oKgAAbxrN6a/osh2//vWvnbhp06ZBe8mSJU7fq6++WrnEkLjoE/nKXXpnz549BfcTXh7m+OOPz7ufE044wYlLOZVXW1vrxD/5yU+C9v79+4veD8ozbNiwvH2LFi2KMZPkMFMBAHhDUQEAeENRAQB406CvqYSvk0SXWvnGN77hxJs3bw7a4duL0fC99dZbXvbz1FNPOXFVVZUTn3zyyUF77NixXo55NDt27AjaDzzwQCzHbEwGDBjgxOGl7xsrZioAAG8oKgAAbxr06a/OnTsH7XPOOafgtuHbNsOnwpBN0dvCr7jiioofc8yYMWW/99ChQ0H78OHDBbdduHBh0F6zZk3BbV988cWyc8LRXXnllU4cPuX+xhtvOH2rV6+OJaekMVMBAHhDUQEAeENRAQB406CuqURXCF22bFnebcNPepSk5557riI5IRlXXXWVE0+bNi1oh5dLOZqzzz7biUu5FXjOnDlOvGXLlrzbLliwIGhv3Lix6GMgXscee6wTDx06NO+28+fPd+LoEjoNFTMVAIA3FBUAgDcUFQCANw3qmsqkSZOc+PTTT8+77apVq5zYWluRnJAOM2bM8LKf8ePHe9kPsin6OIPo0xzD
f0P0yCOPxJJT2jBTAQB4Q1EBAHiT6dNf0RVCf/zjHyeUCYDGIHr6q3///gllkl7MVAAA3lBUAADeUFQAAN5k+prKBRdc4MTHHXdc3m2jy9nX1NRUJCcAaMyYqQAAvKGoAAC8oagAALzJ9DWVo1m3bl3Q/u53v+v07dq1K+50AKDBY6YCAPCGogIA8MaUsjqvMYalfFPIWmuSzqEQxk1qrbXW9kk6iUIYO6mVd+wwUwEAeENRAQB4Q1EBAHhT6i3F1ZK2ViIRlK1D0gkUgXGTTowdlCvv2CnpQj0AAIVw+gsA4A1FBQDgDUUFAOANRQUA4A1FBQDgDUUFAOANRQUA4A1FBQDgDUUFAOANRQUA4A1FBQDgDUUFAOANRQUA4A1FBQDgTaaKijFmizFmoIf9TDTGvFTie5obY+YYYz4xxuwwxkypbx6IT5JjJ/TeE40xO8t9P5KR8OfO1caYl40x+40xK+ubQxxKfUhXY3a3pH9X3cNpTpG0whjzrrV2aaJZIUt+IWmDMvbLHBK1S9KvJJ0p6ZJkUylOZga3MWaepNMlLTLG1BhjpuW+3i9XyXcbY9YZYy4KvWeiMeZ9Y8xeY8zfjTETjDFnSZot6fzcfnYXmcIPJN1nrf3YWrtB0h8kTfT3HaJSUjB2ZIw5X1I3Sf/l8VtDhSU9dqy1/2Ot/W9JH/r+3irGWpuZl6QtkgaG4tMkfSRpqOoK5KW5uK2kVpI+kdQlt207SWfn2hMlvRTZ93hJb+U5bmtJVtLJoa+NlvR20j8TXukeO7n+ppL+V9I5R3o/r3S/khw7oe1+KGll0j+LYl6Zmank8T1JS6y1S6y1h621yyWtUd0/tiQdltTNGNPSWltlrV2fb0fW2j9Za7vn6T4u9989oa/tkfRv9cwfyYlr7EjSzZJes9au9ZY9khTn2MmcrBeVDpLG5Kagu3NTygGS2llr90kaK+lGSVXGmMXGmDPLPE5N7r9fDX3tq5L2lrk/JC+WsWOMOVV1RWW6p7yRvLg+dzIpa0XFRuLtkuZZa08IvVpZa38uSdbaF6y1l6puCrpRdddBjrSfwge19mNJVZJ6hL7cQ1Le30CQOomMHUnn5vbxrjFmh6RHJJ2bu4OwadnfDeKU1NjJpKwVlX9I6hSKn5A03BgzyBjT1BjTwhhzkTGmvTHmZGPMCGNMK0n/Ut1soza0n/bGmGNKOPbjku40xrTO/eZxvaS59f6OEJekxs7zkjpK6pl7/YekNyT1tNbW5n0X0iSxz53P96+6O3Wb5I7VzM+3VSFJX9Qp5SXpCknbJO2WdFvua+dJWqW6W+92Slqsurs12uW+vie3/UpJXXPvOSa33S5J1bmvTZC0vsCxm0uao7qLcP+QNCXpnwevbIydSB4TxYX6TL0S/tyZqLoZTvg1N+mfSaGXySUOAEC9Ze30FwAgxSgqAABvKCoAAG8oKgAAbygqAABvSlql2BjDrWIpZK01SedQCOMmtaqttW2TTqIQxk5q5R07zFSAxmtr0gkgs/KOHYoKAMAbigoAwBuKCgDAG4oKAMAbigoAwBuKCgDAG4oKAMAbigoAwBuKCgDAG4oKAMAbigoAwBuKCgDAm5JWKc6aVq1aBe2HHnrI6bvhhhuceO3atUF7zJgxTt/Wray7BwDFYKYCAPCGogIA8KZBn/5q165d0L7++uudvsOHDzvxOeecE7SHDRvm9D366KMVyA5J6d27txM//fTTTtyxY8eK53DZZZc58YYNG4L29u3bK358pMvw4cOdeOHChU48efLkoD179mynr7a2tnKJlYGZCgDAG4oKAMAbigoAwJsGdU2lbdu2TvzYY48llAnSbNCgQU7cvHnz2HOInkO/7rrrgva4cePiTgcJaNOmTdCeOXNmwW1/+9vfBu05c+Y4fZ9++qnfxOqJmQoAwBuKCgDAm0yf/rr55pudeOTIkU587rnnlrXf73znO07cpIlbe9etWxe0V69eXdYxEK+vfOWLoT506NAE
M6kTXsFBkqZMmRK0wytBSNK+fftiyQnxCn/OtG/fvuC2Tz75ZNA+cOBAxXLygZkKAMAbigoAwBuKCgDAm0xfU/nlL3/pxNGlV8p11VVXFYzDqxaPHTvW6YueK0c6XHzxxUH7/PPPd/pmzJgRdzpq3bq1E3ft2jVoH3vssU4f11Qahuit69OnTy/6vfPmzQva1lpvOVUCMxUAgDcUFQCANxQVAIA3ppTzc8aYxE/mLVmyJGgPGTLE6avPNZWPPvooaNfU1Dh9HTp0KHo/TZs2LTuHcllrTewHLUES46Zbt25OvHLlyqAd/reW3MceSF/+96+EcD6SNGDAgKAdfmSDJO3cubNSaay11vap1M59SMNnji99+rg/6tdffz3vtocOHXLiZs2aVSSnesg7dpipAAC8oagAALxJ/S3FF154oRN36dIlaEdPd5Vy+iv69LRly5YF7T179jh9l1xyiRMXuhXwRz/6UdCeNWtW0fnArzvvvNOJw0ufDB482OmL43SXJJ144olBOzqufd0Oj/QaNWpU0duGP4+yhpkKAMAbigoAwBuKCgDAm9RdU+nYsaMT//nPf3bir33ta0XvK7ycyoIFC5y+e+65x4n3799f1H4kadKkSUE7+rTJ8JIfLVq0cPrCT2+TpIMHD+Y9JkozevRoJ44ub79p06agvWbNmlhyigpfi4teQwnfYrx79+6YMkKcoo/UCPvss8+cuJQlXNKGmQoAwBuKCgDAG4oKAMCb1F1TCT/2VSrtGsqqVauceNy4cUG7urq67Jyi11QefPDBoP3www87feFly6NLqi9cuNCJN2/eXHZOcI0ZM8aJo8vHz5w5M850JH35+uCECROCdm1trdN3//33B22utTUM/fv3LxiHRR9v8Oabb1YipVgwUwEAeENRAQB4k7rTX6WI3hp63XXXOXF9TnkVEj6NFT6lIUl9+/atyDHxZccff3zQ7tevX8Ftk1gyJ3zrueSeyt2wYYPTt2LFilhyQnxK+SxoSEs6MVMBAHhDUQEAeENRAQB4k/prKk2a5K975513XoyZfMGYLx60GM2vUL533323E3//+9/3mldj07x586B92mmnOX1PPvlk3Ol8SefOnfP2vfPOOzFmgiREn/QYFV6Oh2sqAAAcAUUFAOANRQUA4E3qrqnceOONTpzGx6wOHz48aPfq1cvpC+cbzT16TQX1s3fv3qAdXdaie/fuThx+lO+uXbsqks9JJ53kxNHl+MNeeumliuSAZA0YMCBojx8/vuC24ceWf/DBBxXLKW7MVAAA3lBUAADepO70V/jUUlKiT3Ps2rWrE99xxx1F7Wfnzp1OzOqzfn366adBO7ri86hRo5x48eLFQTu6snQpunXr5sSdOnUK2tFVia21efeTxtO6qL82bdoE7UJ/XiBJy5cvr3Q6iWCmAgDwhqICAPCGogIA8CZ111TSYPr06U580003Ff3eLVu2BO1rrrnG6du2bVu98kJ+d911lxOHl9KRpMsvvzxo12cJl+jjFMLXTUp5SuncuXPLzgHpVeg28vCyLJL0u9/9rsLZJIOZCgDAG4oKAMAbigoAwBuuqeQsWbIkaHfp0qXs/bz77rtBm6U44rNx40Ynvvrqq524Z8+eQfuMM84o+zjz58/P2/fYY485cfRR02Hhv7FBdrVv396JCy3NEl2KJfo49IaCmQoAwBuKCgDAm9Sd/oreClpoqYMhQ4YU3Nfvf//7oH3qqacW3DZ8nPosoZGGZWbwZeFVjKMrGvvy/vvvF71tdLkXngSZTf3793fiQp9Xzz77bIWzSQdmKgAAbygqAABvKCoAAG9Sd01l1qxZTjxjxoy82z733HNOXOhaSCnXSUrZdvbs2UVvi4Ytej0wGodxDaVhCC91HxVd0ueRRx6pdDqpwEwFAOANRQUA4E3qTn89/fTTTjx16lQnjj6VsRKiT2zcsGGDE0+aNCloV1VVVTwfZEP0SY+FnvyIhmHQoEF5+6Krku/Zs6fS
6aQCMxUAgDcUFQCANxQVAIA3qbumsnXrViceN26cE48cOTJo33LLLRXJ4YEHHnDiRx99tCLHQcPSokWLgv2sTJx9zZo1c+LOnTvn3fbAgQNOfPDgwYrklDbMVAAA3lBUAADeUFQAAN6k7ppK1OrVq/PGy5Ytc/rCfz8iucvQL1y40OkLL4svuUtqhJ/eCBTr2muvdeLdu3c78X333RdjNqiE6BJO0ac3hh9psGnTplhyShtmKgAAbygqAABvUn/6q5ClS5cWjIE4vf7660788MMPO/GKFSviTAcVUFtb68TTp0934vDSPGvXro0lp7RhpgIA8IaiAgDwhqICAPDGlLI8tzGGtbxTyFqb/xGDKcC4Sa211to+SSdRCGMntfKOHWYqAABvKCoAAG8oKgAAbygqAABvKCoAAG8oKgAAbygqAABvKCoAAG8oKgAAbygqAABvSl36vlrS1kokgrJ1SDqBIjBu0omxg3LlHTslrf0FAEAhnP4CAHhDUQEAeENRAQB4Q1EBAHhDUQEAeENRAQB4Q1EBAHhDUQEAeENRAQB48//a8iADv4lTtAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 6 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Step 4: measure classification accuracy on the test set.\n",
    "# net.eval() switches layers such as dropout/batch-norm (if Net has any) to\n",
    "# inference behaviour; the training cell builds a fresh Net, so it is unaffected.\n",
    "# torch.no_grad() skips building the autograd graph: evaluation needs no\n",
    "# gradients, so this saves memory and time without changing the predictions.\n",
    "net.eval()\n",
    "total_correct = 0\n",
    "with torch.no_grad():\n",
    "    for x, y in test_loader:\n",
    "        # Flatten each image: [b, 1, 28, 28] => [b, 784]\n",
    "        x = x.view(x.size(0), 28 * 28)\n",
    "        out = net(x)\n",
    "        # out: [b, 10] => pred: [b], the index of the largest output per row\n",
    "        pred = out.argmax(dim=1)\n",
    "        # Number of correct predictions; .item() converts the 0-d tensor to a Python number.\n",
    "        correct = pred.eq(y).sum().float().item()\n",
    "        total_correct += correct\n",
    "\n",
    "total_num = len(test_loader.dataset)\n",
    "acc = total_correct / total_num\n",
    "print('test acc:', acc)\n",
    "\n",
    "# Visualize predictions for the first images of one test batch.\n",
    "x, y = next(iter(test_loader))\n",
    "with torch.no_grad():\n",
    "    out = net(x.view(x.size(0), 28 * 28))\n",
    "pred = out.argmax(dim=1)\n",
    "plot_image(x, pred, 'test')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8257aea5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
